{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9f422bf4",
   "metadata": {},
   "source": [
    "# TalkNet Training\n",
    "\n",
    "This notebook is designed to provide a guide on how to train TalkNet as part of the TTS pipeline. It contains the following two sections:\n",
    "  1. **Introduction**: TalkNet in NeMo\n",
    "  2. **Preprocessing**: how to prepare data for Talknet \n",
    "  3. **Training**: example of TalkNet training"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d38c43a0",
   "metadata": {},
   "source": [
    "# License\n",
    "\n",
    "> Copyright 2020 NVIDIA. All Rights Reserved.\n",
    "> \n",
    "> Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "> you may not use this file except in compliance with the License.\n",
    "> You may obtain a copy of the License at\n",
    "> \n",
    ">     http://www.apache.org/licenses/LICENSE-2.0\n",
    "> \n",
    "> Unless required by applicable law or agreed to in writing, software\n",
    "> distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "> WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "> See the License for the specific language governing permissions and\n",
    "> limitations under the License."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39270500",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "You can either run this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n",
    "Instructions for setting up Colab are as follows:\n",
    "1. Open a new Python 3 notebook.\n",
    "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n",
    "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n",
    "4. Run this cell to set up dependencies# .\n",
    "\"\"\"\n",
    "# # If you're using Colab and not running locally, uncomment and run this cell.\n",
    "# !apt-get install sox libsndfile1 ffmpeg\n",
    "# !pip install wget unidecode pysptk\n",
    "# BRANCH = 'main'\n",
    "# !python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "198d92c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install pysptk\n",
    "import json\n",
    "import nemo\n",
    "import torch\n",
    "import librosa\n",
    "import numpy as np\n",
    "\n",
    "from pysptk import sptk\n",
    "from pathlib import Path\n",
    "from tqdm.notebook import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2831436",
   "metadata": {},
   "source": [
    "# Introduction\n",
    "\n",
    "TalkNet is a neural network that converts text characters into a mel spectrogram. For more details about model, please refer to Nvidia's TalkNet Model Card, or the original [paper](https://arxiv.org/abs/2104.08189).\n",
    "\n",
    "TalkNet like most NeMo models is defined as a LightningModule, allowing for easy training via PyTorch Lightning, and parameterized by a configuration, currently defined via a yaml file and loading using Hydra.\n",
    "\n",
    "Let's take a look using NeMo's pretrained model and how to use it to generate spectrograms."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5e6d456",
   "metadata": {},
   "source": [
    "## NOTE\n",
    "\n",
    "<span style=\"color:red\">TalkNet loading does not currently work in main and above. Please revert to r1.2.0 if interested in testing TalkNet inference with our pre-trained model.</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ea69725",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Load the TalkNetSpectModel\n",
    "# from nemo.collections.tts.models import TalkNetSpectModel\n",
    "# from nemo.collections.tts.models.base import SpectrogramGenerator\n",
    "\n",
    "# # Let's see what pretrained models are available\n",
    "# print(TalkNetSpectModel.list_available_models())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ce0a10e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # We can load the pre-trained model as follows\n",
    "# pretrained_model = \"tts_en_talknet\"\n",
    "# model = TalkNetSpectModel.from_pretrained(pretrained_model)\n",
    "\n",
    "# # Load and attach durs and pitch predictors\n",
    "# from nemo.collections.tts.models import TalkNetPitchModel\n",
    "# pitch_model = TalkNetPitchModel.from_pretrained(pretrained_model)\n",
    "# from nemo.collections.tts.models import TalkNetDursModel\n",
    "# durs_model = TalkNetDursModel.from_pretrained(pretrained_model)\n",
    "# model.add_module('_pitch_model', pitch_model)\n",
    "# model.add_module('_durs_model', durs_model)\n",
    "\n",
    "# model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4407cec0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # TalkNet is a SpectrogramGenerator\n",
    "# assert isinstance(model, SpectrogramGenerator)\n",
    "\n",
    "# # SpectrogramGenerators in NeMo have two helper functions:\n",
    "# #   1. parse(text: str, **kwargs) which takes an English string and produces a token tensor\n",
    "# #   2. generate_spectrogram(tokens: 'torch.tensor', **kwargs) which takes the token tensor and generates a spectrogram\n",
    "# # Let's try it out\n",
    "# tokens = model.parse(text=\"Hey, this produces speech!\")\n",
    "# spectrogram = model.generate_spectrogram(tokens=tokens)\n",
    "\n",
    "# # Now we can visualize the generated spectrogram\n",
    "# # If we want to generate speech, we have to use a vocoder in conjunction to a spectrogram generator.\n",
    "# # Refer to the TTS Inference notebook on how to convert spectrograms to speech.\n",
    "# from matplotlib.pyplot import imshow\n",
    "# from matplotlib import pyplot as plt\n",
    "# %matplotlib inline\n",
    "# imshow(spectrogram.cpu().detach().numpy()[0,...], origin=\"lower\")\n",
    "# plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "97fb144c",
   "metadata": {},
   "source": [
    "# Preprocessing\n",
    "\n",
    "Now that we looked at the TalkNet model, let's see how to prepare all data for training it. \n",
    "\n",
    "Firstly, let's download all necessary training scripts and configs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39f75ee7",
   "metadata": {},
   "outputs": [],
   "source": [
    "!wget https://raw.githubusercontent.com/NVIDIA/NeMo/stable/examples/tts/talknet_durs.py\n",
    "!wget https://raw.githubusercontent.com/NVIDIA/NeMo/stable/examples/tts/talknet_pitch.py\n",
    "!wget https://raw.githubusercontent.com/NVIDIA/NeMo/stable/examples/tts/talknet_spect.py\n",
    "\n",
    "!mkdir -p conf && cd conf \\\n",
    "&& wget https://raw.githubusercontent.com/NVIDIA/NeMo/stable/examples/tts/conf/talknet-durs.yaml \\\n",
    "&& wget https://raw.githubusercontent.com/NVIDIA/NeMo/stable/examples/tts/conf/talknet-pitch.yaml \\\n",
    "&& wget https://raw.githubusercontent.com/NVIDIA/NeMo/stable/examples/tts/conf/talknet-spect.yaml \\\n",
    "&& cd .."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "224c7ed3",
   "metadata": {},
   "source": [
    "We will show example of preprocessing and training using small part of AN4 dataset. It consists of recordings of people spelling out addresses, names, telephone numbers, etc., one letter or number at a time, as well as their corresponding transcripts. Let's download data and prepare manifests.\n",
    "\n",
    "*NOTE: The sample data is not enough data to properly train a TalkNet. This will not result in a trained TalkNet and is used to just as example.*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec473890",
   "metadata": {},
   "outputs": [],
   "source": [
    "!wget https://github.com/NVIDIA/NeMo/releases/download/v0.11.0/test_data.tar.gz && mkdir -p tests/data && tar xzf test_data.tar.gz -C tests/data\n",
    "\n",
    "# Just like ASR, the TalkNet require .json files to define the training and validation data.\n",
    "!cat tests/data/asr/an4_val.json\n",
    "!cat tests/data/asr/an4_train.json tests/data/asr/an4_val.json > tests/data/asr/an4_all.json "
   ]
  },
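  {
   "cell_type": "markdown",
   "id": "7c1e4a90",
   "metadata": {},
   "source": [
    "The manifest printed above is a JSON-lines file: one JSON object per line. As a minimal sketch (not part of the NeMo API), the next cell parses two made-up entries and checks that each one has the `audio_filepath`, `duration` and `text` fields. The sample paths, the `REQUIRED_FIELDS` tuple and the helper name `read_manifest_lines` are hypothetical and only serve the illustration. To check the real files, you could pass `open('tests/data/asr/an4_all.json')` instead of `sample_lines`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2b85d13",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Required keys assumed for this sketch; adjust to your own manifests.\n",
    "REQUIRED_FIELDS = ('audio_filepath', 'duration', 'text')\n",
    "\n",
    "# Two made-up entries serialized as JSON lines (one object per line).\n",
    "sample_lines = [\n",
    "    json.dumps({'audio_filepath': 'wavs/utt_0001.wav', 'duration': 2.5, 'text': 'g o'}),\n",
    "    json.dumps({'audio_filepath': 'wavs/utt_0002.wav', 'duration': 3.1, 'text': 'o n e'}),\n",
    "]\n",
    "\n",
    "\n",
    "def read_manifest_lines(lines):\n",
    "    \"\"\"Parses JSON-lines entries and checks that the required fields exist.\"\"\"\n",
    "    entries = []\n",
    "    for line_no, line in enumerate(lines, start=1):\n",
    "        if not line.strip():\n",
    "            continue  # Skip empty lines.\n",
    "        entry = json.loads(line)\n",
    "        missing = [key for key in REQUIRED_FIELDS if key not in entry]\n",
    "        if missing:\n",
    "            raise ValueError(f'line {line_no}: missing fields {missing}')\n",
    "        entries.append(entry)\n",
    "    return entries\n",
    "\n",
    "\n",
    "entries = read_manifest_lines(sample_lines)\n",
    "total_duration = sum(entry['duration'] for entry in entries)\n",
    "print(f'{len(entries)} utterances, {total_duration:.1f} s of audio')"
   ]
  },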
  {
   "cell_type": "markdown",
   "id": "b9f142c1",
   "metadata": {},
   "source": [
    "## Extracting phoneme ground truth durations"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1063922",
   "metadata": {},
   "source": [
    "As a part of whole model, you will need to train duration predictor. We will extract phoneme ground truth durations from ASR model (QuartzNet5x5, trained on LibriTTS) using forward-backward algorithm (see paper for details). Let's download pretrained ASR model and define auxiliary functions. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0deff219",
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo.collections.asr.models import EncDecCTCModel\n",
    "asr_model = EncDecCTCModel.from_pretrained(model_name=\"asr_talknet_aligner\").cpu().eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1d21797",
   "metadata": {},
   "outputs": [],
   "source": [
    "def forward_extractor(tokens, log_probs, blank):\n",
    "    \"\"\"Computes states f and p.\"\"\"\n",
    "    n, m = len(tokens), log_probs.shape[0]\n",
    "    # `f[s, t]` -- max sum of log probs for `s` first codes\n",
    "    # with `t` first timesteps with ending in `tokens[s]`.\n",
    "    f = np.empty((n + 1, m + 1), dtype=float)\n",
    "    f.fill(-(10 ** 9))\n",
    "    p = np.empty((n + 1, m + 1), dtype=int)\n",
    "    f[0, 0] = 0.0  # Start\n",
    "    for s in range(1, n + 1):\n",
    "        c = tokens[s - 1]\n",
    "        for t in range((s + 1) // 2, m + 1):\n",
    "            f[s, t] = log_probs[t - 1, c]\n",
    "            # Option #1: prev char is equal to current one.\n",
    "            if s == 1 or c == blank or c == tokens[s - 3]:\n",
    "                options = f[s : (s - 2 if s > 1 else None) : -1, t - 1]\n",
    "            else:  # Is not equal to current one.\n",
    "                options = f[s : (s - 3 if s > 2 else None) : -1, t - 1]\n",
    "            f[s, t] += np.max(options)\n",
    "            p[s, t] = np.argmax(options)\n",
    "    return f, p\n",
    "\n",
    "\n",
    "def backward_extractor(f, p):\n",
    "    \"\"\"Computes durs from f and p.\"\"\"\n",
    "    n, m = f.shape\n",
    "    n -= 1\n",
    "    m -= 1\n",
    "    durs = np.zeros(n, dtype=int)\n",
    "    if f[-1, -1] >= f[-2, -1]:\n",
    "        s, t = n, m\n",
    "    else:\n",
    "        s, t = n - 1, m\n",
    "    while s > 0:\n",
    "        durs[s - 1] += 1\n",
    "        s -= p[s, t]\n",
    "        t -= 1\n",
    "    assert durs.shape[0] == n\n",
    "    assert np.sum(durs) == m\n",
    "    assert np.all(durs[1::2] > 0)\n",
    "    return durs\n",
    "\n",
    "def preprocess_tokens(tokens, blank):\n",
    "    new_tokens = [blank]\n",
    "    for c in tokens:\n",
    "        new_tokens.extend([c, blank])\n",
    "    tokens = new_tokens\n",
    "    return tokens"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a6c9b28",
   "metadata": {},
   "source": [
    "Now we can run extraction and save result. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9171952c",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_config = {\n",
    "    'manifest_filepath': \"tests/data/asr/an4_all.json\",\n",
    "    'sample_rate': 16000,\n",
    "    'labels': asr_model.decoder.vocabulary,\n",
    "    'batch_size': 1,\n",
    "}\n",
    "\n",
    "parser = nemo.collections.asr.data.audio_to_text.AudioToCharWithDursF0Dataset.make_vocab(\n",
    "    notation='phonemes', punct=True, spaces=True, stresses=False, add_blank_at=\"last\"\n",
    ")\n",
    "\n",
    "dataset = nemo.collections.asr.data.audio_to_text._AudioTextDataset(\n",
    "    manifest_filepath=data_config['manifest_filepath'], sample_rate=data_config['sample_rate'], parser=parser,\n",
    ")\n",
    "\n",
    "dl = torch.utils.data.DataLoader(\n",
    "    dataset=dataset, batch_size=data_config['batch_size'], collate_fn=dataset.collate_fn, shuffle=False,\n",
    ")\n",
    "\n",
    "blank_id = asr_model.decoder.num_classes_with_blank - 1\n",
    "\n",
    "dur_data = {}\n",
    "for sample_idx, test_sample in tqdm(enumerate(dl), total=len(dl)):\n",
    "    log_probs, _, greedy_predictions = asr_model(\n",
    "        input_signal=test_sample[0], input_signal_length=test_sample[1]\n",
    "    )\n",
    "\n",
    "    log_probs = log_probs[0].cpu().detach().numpy()\n",
    "    seq_ids = test_sample[2][0].cpu().detach().numpy()\n",
    "\n",
    "    target_tokens = preprocess_tokens(seq_ids, blank_id)\n",
    "\n",
    "    f, p = forward_extractor(target_tokens, log_probs, blank_id)\n",
    "    durs = backward_extractor(f, p)\n",
    "\n",
    "    dur_key = Path(dl.dataset.manifest_processor.collection[sample_idx].audio_file).stem\n",
    "    dur_data[dur_key] = {\n",
    "        'blanks': torch.tensor(durs[::2], dtype=torch.long).cpu().detach(), \n",
    "        'tokens': torch.tensor(durs[1::2], dtype=torch.long).cpu().detach()\n",
    "    }\n",
    "\n",
    "    del test_sample\n",
    "\n",
    "torch.save(dur_data, \"tests/data/asr/an4_durations.pt\")"
   ]
  },
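  {
   "cell_type": "markdown",
   "id": "3ad97f26",
   "metadata": {},
   "source": [
    "The `durs` array returned by `backward_extractor` alternates blank and token durations, which is why the cell above splits it with `durs[::2]` and `durs[1::2]`. The toy example in the next cell, with made-up numbers and a hypothetical helper `split_durations`, sketches that layout and the properties asserted in `backward_extractor`: every token lasts at least one frame, and all durations together sum to the number of frames."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b60c8e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_durations(durs):\n",
    "    \"\"\"Splits interleaved durations into blank and token durations.\"\"\"\n",
    "    if len(durs) % 2 != 1:\n",
    "        raise ValueError('expected 2 * n_tokens + 1 durations')\n",
    "    return durs[::2], durs[1::2]\n",
    "\n",
    "\n",
    "# Made-up durations (in frames) for 3 tokens: blank, tok, blank, tok, blank, tok, blank.\n",
    "toy_durs = [2, 4, 0, 3, 1, 5, 2]\n",
    "toy_blanks, toy_tokens = split_durations(toy_durs)\n",
    "\n",
    "assert len(toy_blanks) == len(toy_tokens) + 1  # One more blank than tokens.\n",
    "assert all(d > 0 for d in toy_tokens)  # Every token spans at least one frame.\n",
    "print('blanks:', toy_blanks, 'tokens:', toy_tokens, 'frames:', sum(toy_durs))"
   ]
  },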
  {
   "cell_type": "markdown",
   "id": "289550ca",
   "metadata": {},
   "source": [
    "## Extracting ground truth f0"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44037b0e",
   "metadata": {},
   "source": [
    "The second model, that you will need to train before spectrogram generator, is pitch predictor. As labels for pitch predictor, we will use f0 from audio using `pysptk` library (see paper for details). Let's extract f0, calculate stats (mean & std) and save it all."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd0fbe8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_f0(audio_file, sample_rate=16000, hop_length=256):\n",
    "    audio, _ = librosa.load(audio_file, sample_rate)\n",
    "#     audio = torchaudio.load(audio_file)[0].squeeze().numpy()\n",
    "    f0 = sptk.swipe(audio.astype(np.float64), sample_rate, hopsize=hop_length)\n",
    "    # Hack to make f0 and mel lengths equal\n",
    "    if len(audio) % hop_length == 0:\n",
    "        f0 = np.pad(f0, pad_width=[0, 1])\n",
    "    return torch.from_numpy(f0.astype(np.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3038e526",
   "metadata": {},
   "outputs": [],
   "source": [
    "f0_data = {}\n",
    "with open(\"tests/data/asr/an4_all.json\") as f:\n",
    "    for l in tqdm(f):\n",
    "        audio_path = json.loads(l)[\"audio_filepath\"]\n",
    "        f0_data[Path(audio_path).stem] = extract_f0(audio_path)\n",
    "\n",
    "# calculate f0 stats (mean & std) only for train set\n",
    "with open(\"tests/data/asr/an4_train.json\") as f:\n",
    "    train_ids = {Path(json.loads(l)[\"audio_filepath\"]).stem for l in f}\n",
    "all_f0 = torch.cat([f0[f0 >= 1e-5] for f0_id, f0 in f0_data.items() if f0_id in train_ids])\n",
    "\n",
    "F0_MEAN, F0_STD = all_f0.mean().item(), all_f0.std().item()        \n",
    "torch.save(f0_data, \"tests/data/asr/an4_f0s.pt\")"
   ]
  },
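  {
   "cell_type": "markdown",
   "id": "c94d1b7a",
   "metadata": {},
   "source": [
    "The statistics in the cell above are computed only over frames with `f0 >= 1e-5`, so frames with zero pitch do not pull the mean down. The next cell is a small sketch of the same filtering on made-up f0 tracks, using only the standard library. The names `toy_f0`, `toy_mean` and `toy_std` are hypothetical. `statistics.stdev` is the sample standard deviation (dividing by n - 1), which matches the default behaviour of `torch.Tensor.std`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08f3e6b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import statistics\n",
    "\n",
    "# Made-up f0 tracks in Hz; 0.0 marks frames treated as unvoiced in this sketch.\n",
    "toy_f0 = {\n",
    "    'utt_a': [0.0, 110.0, 120.0, 0.0],\n",
    "    'utt_b': [0.0, 130.0, 0.0],\n",
    "}\n",
    "\n",
    "# Keep only voiced frames, using the same threshold as the cell above.\n",
    "voiced = [value for track in toy_f0.values() for value in track if value >= 1e-5]\n",
    "\n",
    "toy_mean = statistics.mean(voiced)\n",
    "toy_std = statistics.stdev(voiced)  # Sample standard deviation (n - 1).\n",
    "print(f'mean={toy_mean:.1f} Hz, std={toy_std:.1f} Hz over {len(voiced)} voiced frames')"
   ]
  },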
  {
   "cell_type": "markdown",
   "id": "b31f38d6",
   "metadata": {},
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1e68b286",
   "metadata": {},
   "source": [
    "Now we are ready for training our models! Let's try to train TalkNet parts consequentially."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97e875eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "!python talknet_durs.py sample_rate=16000 \\\n",
    "train_dataset=tests/data/asr/an4_train.json \\\n",
    "validation_datasets=tests/data/asr/an4_val.json \\\n",
    "durs_file=tests/data/asr/an4_durations.pt \\\n",
    "f0_file=tests/data/asr/an4_f0s.pt \\\n",
    "trainer.max_epochs=3 \\\n",
    "trainer.accelerator=null \\\n",
    "trainer.check_val_every_n_epoch=1 \\\n",
    "model.train_ds.dataloader_params.batch_size=6 \\\n",
    "model.train_ds.dataloader_params.num_workers=0 \\\n",
    "model.validation_ds.dataloader_params.num_workers=0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "924ecc0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "!python talknet_pitch.py sample_rate=16000 \\\n",
    "train_dataset=tests/data/asr/an4_train.json \\\n",
    "validation_datasets=tests/data/asr/an4_val.json \\\n",
    "durs_file=tests/data/asr/an4_durations.pt \\\n",
    "f0_file=tests/data/asr/an4_f0s.pt \\\n",
    "trainer.max_epochs=3 \\\n",
    "trainer.accelerator=null \\\n",
    "trainer.check_val_every_n_epoch=1 \\\n",
    "model.f0_mean={F0_MEAN} \\\n",
    "model.f0_std={F0_STD} \\\n",
    "model.train_ds.dataloader_params.batch_size=6 \\\n",
    "model.train_ds.dataloader_params.num_workers=0 \\\n",
    "model.validation_ds.dataloader_params.num_workers=0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57c4bc7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "!python talknet_spect.py sample_rate=16000 \\\n",
    "train_dataset=tests/data/asr/an4_train.json \\\n",
    "validation_datasets=tests/data/asr/an4_val.json \\\n",
    "durs_file=tests/data/asr/an4_durations.pt \\\n",
    "f0_file=tests/data/asr/an4_f0s.pt \\\n",
    "trainer.max_epochs=3 \\\n",
    "trainer.accelerator=null \\\n",
    "trainer.check_val_every_n_epoch=1 \\\n",
    "model.train_ds.dataloader_params.batch_size=6 \\\n",
    "model.train_ds.dataloader_params.num_workers=0 \\\n",
    "model.validation_ds.dataloader_params.num_workers=0"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df8ac1e7",
   "metadata": {},
   "source": [
    "That's it!\n",
    "\n",
    "In order to train TalkNet for real purposes, it is highly recommended to obtain high quality speech data with the following properties:\n",
    "\n",
    "* Sampling rate of 22050Hz or higher\n",
    "* Single speaker\n",
    "* Speech should contain a variety of speech phonemes\n",
    "* Audio split into segments of 1-10 seconds\n",
    "* Audio segments should not have silence at the beginning and end\n",
    "* Audio segments should not contain long silences inside"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
